"""
Clean raw patent data from PatentsView in preparation for generating patent embeddings.
For the data schema, see https://patentsview.org/download/data-download-dictionary
Adjust the settings in the Configs class as needed before running.
Based on the PatentsView data as of 2022-12-27.

example use:
python3 clean_raw_patent.py
"""

import argparse
import glob
import logging
import os
from typing import Optional

import numpy as np
import pandas as pd
import tqdm


# lazy way of configs
class Configs:
    """Hard-coded settings for cleaning the raw PatentsView patent table.

    Edit the values here when the data location or the year window changes.
    """

    def __init__(self, work_dir: Optional[str] = None):
        """
        Args:
            work_dir: root directory that holds ``raw_data/`` and
                ``patentsberta/``. When None, falls back to the module-level
                ``WORK_DIR`` global (set by the script entry point).
        """
        if work_dir is None:
            work_dir = WORK_DIR
        self.min_year = 1976  # min grant year kept for patents (inclusive)
        self.max_year = 2015  # max grant year kept for patents (inclusive)
        self.PATENT_PATH = os.path.join(work_dir, "raw_data/", "g_patent.tsv.zip")
        self.SAVE_PATH = os.path.join(work_dir, "patentsberta")
        # columns read from g_patent.tsv
        self.patent_cols = [
            "patent_id",
            "patent_date",
            "patent_title",
            "patent_abstract",
            "withdrawn",
        ]


def load_patent(
    PATENT_PATH: str,
    cols_keep: list,
    min_year: int,
    max_year: int,
    save_dir: Optional[str] = None,
) -> pd.DataFrame:
    """
    Read g_patent.tsv from the raw (zipped) PatentsView file, keep patents
    that are not withdrawn and were granted between min_year and max_year
    (inclusive), and save the result as patent_raw.csv.

    Args:
        PATENT_PATH: path to the g_patent.tsv(.zip) file.
        cols_keep: columns to read. "patent_date" and "withdrawn" are always
            read as well, because the filtering below needs them.
        min_year: earliest grant year kept (inclusive).
        max_year: latest grant year kept (inclusive).
        save_dir: directory where patent_raw.csv is written. When None,
            defaults to WORK_DIR/patentsberta/ (module-level global).

    Returns:
        The filtered DataFrame. "patent_date" is replaced by an int16
        "patent_year" column and "withdrawn" is dropped. Rows whose date
        cannot be parsed into a year are dropped.
    """
    log = logging.getLogger(__name__)

    # Filtering needs these two columns; dedupe while preserving order.
    usecols = list(dict.fromkeys([*cols_keep, "patent_date", "withdrawn"]))
    df = pd.read_csv(PATENT_PATH, usecols=usecols, sep="\t")

    # The first four characters of the date string are the grant year; missing
    # or malformed dates become NaN and fail the range check below.
    year = pd.to_numeric(df["patent_date"].astype(str).str[:4], errors="coerce")
    keep = year.between(min_year, max_year) & (df["withdrawn"] == 0)

    # Non-inplace chain: avoids SettingWithCopyWarning on the filtered slice.
    df = (
        df.loc[keep]
        .drop(columns=["withdrawn"])
        .assign(patent_date=year[keep].astype("int16"))
        .rename(columns={"patent_date": "patent_year"})
    )
    log.info(
        "Number of patents between year %s and %s: %s", min_year, max_year, len(df)
    )

    if save_dir is None:
        save_dir = os.path.join(WORK_DIR, "patentsberta/")
    df.to_csv(os.path.join(save_dir, "patent_raw.csv"), index=False)
    return df


def main():
    """Build the default configuration and run the patent-loading step.

    load_patent filters the raw table and writes patent_raw.csv; the returned
    DataFrame is not used further here.
    """
    cfg = Configs()
    load_patent(
        PATENT_PATH=cfg.PATENT_PATH,
        cols_keep=cfg.patent_cols,
        min_year=cfg.min_year,
        max_year=cfg.max_year,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Clean raw PatentsView g_patent data for patent embeddings."
    )
    parser.add_argument(
        "--work-dir",
        default="/Volumes/Zihao_SSD2/PatentsView/",
        help="root directory containing raw_data/ and patentsberta/",
    )
    args = parser.parse_args()

    # Configs and load_patent read this module-level global.
    WORK_DIR = args.work_dir
    logging.basicConfig(
        format="%(asctime)s:%(levelname)s:%(message)s", level=logging.INFO
    )
    logger = logging.getLogger(__name__)

    main()
